import json
from torch.utils.data import Dataset
import random

class MQUAKE(Dataset):
    """Map-style dataset over one ``MQuAKE-Remastered-<name>.json`` file.

    Besides holding the list of cases, an instance keeps ``rand_list``: the
    ``case_id`` values of the cases that are treated as *edited*. How that
    list is built depends on the dataset name (see ``__init__``).
    """

    # Names accepted by the constructor; each one selects the file
    # ``datasets/MQuAKE-Remastered-<name>.json``.
    _KNOWN_DATASETS = ('CF-3k', 'CF-9k', 'CF-3151', 'CF-6334', 'T', 'T-old', 'CF-3k-old')

    def __init__(self, dataset_name, dataset_dir, edit_num, seed_num):
        """Load the JSON file and decide which cases count as edited.

        Args:
            dataset_name: one of ``_KNOWN_DATASETS``.
            dataset_dir: prefix that is concatenated *as-is* with
                ``'datasets/MQuAKE-Remastered-<name>.json'``, so it has to end
                with a path separator (or be empty).
            edit_num: for ``'CF-6334'``, ``str(edit_num)`` is the key looked up
                in every case's ``'6334_split'`` mapping. For every other
                dataset it is the number of case ids to sample as edited; a
                falsy value (``0``/``None``) means no edited cases.
            seed_num: seed handed to ``random.seed`` before sampling. Unused
                for ``'CF-6334'`` and when ``edit_num`` is falsy.

        Raises:
            ValueError: if ``dataset_name`` is not a known dataset (raised
                directly), or if ``edit_num`` is larger than the number of
                cases (raised by ``random.sample``).
        """
        if dataset_name not in self._KNOWN_DATASETS:
            raise ValueError(f"Dataset name {dataset_name} unknown.")
        self.dataset_name = dataset_name

        # JSON text is UTF-8; be explicit so decoding does not depend on the
        # platform's locale default.
        json_path = dataset_dir + f'datasets/MQuAKE-Remastered-{dataset_name}.json'
        with open(json_path, 'r', encoding='utf-8') as f:
            self.dataset = json.load(f)

        if dataset_name == 'CF-6334':
            # The file itself carries the split: keep only the test cases for
            # this edit budget and record which of them are edited.
            split_key = str(edit_num)
            kept_cases = []
            self.rand_list = []
            for case in self.dataset:
                labels = case['6334_split'][split_key]
                if 'test_edited_unique' in labels or 'test_edited' in labels:
                    kept_cases.append(case)
                    self.rand_list.append(case['case_id'])
                # NOTE(review): this is a separate ``if``, so a case labelled
                # both edited and 'test_unedited' is kept twice - confirm the
                # labels are mutually exclusive in the data.
                if 'test_unedited' in labels:
                    kept_cases.append(case)
            self.dataset = kept_cases
        elif not edit_num:
            self.rand_list = []
        else:
            # Reseeds the module-level RNG so the sampled ids are reproducible
            # for a given seed.
            random.seed(seed_num)
            all_case_ids = [case['case_id'] for case in self.dataset]
            self.rand_list = random.sample(all_case_ids, edit_num)

        self.length = len(self.dataset)

    def __len__(self):
        """Return the number of cases kept in the dataset."""
        return self.length

    def __getitem__(self, idx):
        """Return the case at position ``idx``.

        Negative indices are rejected rather than wrapped around.

        Raises:
            IndexError: if ``idx`` is outside ``[0, len(self))``.
        """
        if idx < 0 or idx >= self.length:
            raise IndexError("Index out of range.")
        return self.dataset[idx]

    def get_length(self):
        """Return the number of cases (same value as ``len(self)``)."""
        return self.length

    def get_dataset(self):
        """Return the underlying list of cases (not a copy)."""
        return self.dataset

    def get_randlist(self):
        """Return the list of ``case_id`` values treated as edited (not a copy)."""
        return self.rand_list

    @staticmethod
    def check_answer(edit_flag, instance, ans):
        """Check ``ans`` against the expected answer of ``instance``.

        The comparison is case-insensitive (both sides are upper-cased).

        Args:
            edit_flag: if truthy, compare against ``instance['new_answer']`` and
                ``instance['new_answer_alias']``; otherwise against
                ``instance['answer']`` and ``instance['answer_alias']``.
            instance: mapping holding the answer string and its alias list.
            ans: the candidate answer, or ``None``.

        Returns:
            ``True`` if ``ans`` equals the answer or any alias, ``False``
            otherwise (including when ``ans`` is ``None``).
        """
        if ans is None:
            return False

        key_prefix = "new_" if edit_flag else ""
        candidate = ans.upper()

        if candidate == instance[key_prefix + "answer"].upper():
            return True
        return candidate in [alias.upper() for alias in instance[key_prefix + "answer_alias"]]

    @staticmethod
    def verify_subquestion_path(prompt, d, edit_flag):
        """
        Copied from https://github.com/Hengrui-Gu/PokeMQA/blob/main/PokeMQA-turbo_n_edited.py#L50

        Check the reasoning path contained in ``prompt`` (Hop-Acc).

        ``prompt`` is stripped and split on the literal ``'Subquestion:'``.
        The number of pieces after the first must equal the number of hops in
        ``d['new_single_hops']`` (``edit_flag`` truthy) or ``d['single_hops']``.
        For each piece, the text after its last ``': '`` is taken as the
        intermediate answer and must match the hop's ``'answer'`` or be in its
        ``'answer_alias'``; this comparison is exact (case-sensitive).

        Returns:
            ``True`` if every hop matches, ``False`` otherwise.
        """
        path = d['new_single_hops'] if edit_flag else d['single_hops']

        # Element 0 is whatever precedes the first 'Subquestion:' marker.
        segments = prompt.strip().split('Subquestion:')

        # subquestion verification: one segment per expected hop
        if len(segments) != len(path) + 1:
            return False

        # reasoning path verification
        for segment, hop in zip(segments[1:], path):
            inter_ans = segment.strip().split(': ')[-1]
            if inter_ans != hop["answer"] and inter_ans not in hop["answer_alias"]:
                return False

        return True

    def get_edits_without_contamination(self, rand_list, problem_case):
        """
        Collect the edits of all edited cases that do not contaminate ``problem_case``.

        Inputs:
          rand_list: the case ids of the edited cases (any iterable of hashable ids)
          problem_case: one multi-hop case that we want to get the set of edits that wouldn't contaminate it

        Whether ``problem_case`` is itself edited is derived from ``rand_list``: if its
        ``case_id`` is in there, its correct path is ``problem_case['orig']['new_triples']``,
        otherwise ``problem_case['orig']['triples']``. An edit contaminates the case when it
        has the same subject and relation as some hop of that path but a different object.

        Outputs:
          nl_facts: a list of natural language edits. e.g. "John Milton is a citizen of Spain"
          triple_labeled: a list of edits in triples of text. e.g. "(John Milton, {} is a citizen of, Spain)"
          triple_ids: similar to above but in id form. E.g. "(Q79759, P27, Q29)", where Q79759, P27, Q29 are ids of
            entity or relation.
          case_index: the "caseid-1" (used for list index accessing) of the case that the j-th edit are in.

        NOTE: the returned values may contain duplicate edits (since an edit may come from distinct multi-hop cases).

        """
        # Build the lookup once: membership is tested for every case below.
        edited_ids = set(rand_list)

        edit_flag = problem_case['case_id'] in edited_ids
        correct_path = problem_case['orig']['new_triples' if edit_flag else 'triples']

        nl_facts = []
        triple_labeled = []
        triple_ids = []
        case_index = []

        for d in self.dataset:
            if d['case_id'] not in edited_ids:
                continue
            # Each edit triple is paired positionally with its rewrite record.
            for edit, edit_extra_info in zip(d['orig']['edit_triples'], d['requested_rewrite']):
                # same subject and relation but different answer to a specific hop -> contamination
                contaminates = any(
                    edit[0] == hop[0] and edit[1] == hop[1] and edit[2] != hop[2]
                    for hop in correct_path
                )
                if contaminates:
                    continue

                # add this edit to the edit bank:
                subject = edit_extra_info['subject']
                template = edit_extra_info['prompt']
                new_target = edit_extra_info['target_new']['str']

                nl_facts.append(f'{template.format(subject)} {new_target}')
                triple_labeled.append((subject, template, new_target))
                triple_ids.append(edit)
                case_index.append(d['case_id'] - 1)

        return nl_facts, triple_labeled, triple_ids, case_index
